{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "727cb0f0c9384346833b39c0817e44d8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Image(value=b'', height='224', layout=\"Layout(width='100%')\", width='224'), Checkbox(value=True…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\"\"\"\n",
    "Important Stuff\n",
    "\"\"\"\n",
    "\n",
    "# Index of the webcam to use (0 is the built-in camera on most laptops)\n",
    "DEVICE = 0\n",
    "\n",
    "# Path to the object detection network weights file\n",
    "WEIGHTS = 'weights/squirrel_detector.hdf5'\n",
    "\n",
    "\"\"\"\n",
    "User Interface Stuff (You can ignore this)\n",
    "\"\"\"\n",
    "import IPython\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display\n",
    "\n",
    "# Handle to the running IPython shell\n",
    "ipython = IPython.get_ipython()\n",
    "\n",
    "# Shared widget style: let each description use as much width as its text needs\n",
    "style = {'description_width': 'initial'}\n",
    "\n",
    "\n",
    "def _full_width():\n",
    "    \"\"\"Return a fresh layout that stretches a widget to 100% width.\"\"\"\n",
    "    return widgets.Layout(width='100%')\n",
    "\n",
    "\n",
    "def _checkbox(description):\n",
    "    \"\"\"Build a checkbox that starts ticked and uses the shared style.\"\"\"\n",
    "    return widgets.Checkbox(value=True, description=description, style=style)\n",
    "\n",
    "\n",
    "def _threshold_slider(description, value):\n",
    "    \"\"\"Build a full-width slider over [0, 1] that moves in steps of 0.01.\"\"\"\n",
    "    return widgets.FloatSlider(min=0, max=1, value=value, step=0.01,\n",
    "                               description=description,\n",
    "                               layout=_full_width(),\n",
    "                               style=style)\n",
    "\n",
    "\n",
    "# Image widget that receives the PNG-encoded frames\n",
    "w_image = widgets.Image(width=224, height=224, format='png',\n",
    "                        layout=_full_width())\n",
    "\n",
    "# Controls, created in the same order in which they are stacked below\n",
    "w_heatmap = _checkbox('Overlay heatmap on image')\n",
    "w_bbox = _checkbox('Show Bounding Boxes (BBoxes)')\n",
    "w_bbox_thresh = _threshold_slider('BBox Confidence Threshold', 0.99)\n",
    "w_merge = _checkbox('Merge close BBoxes')\n",
    "w_merge_thresh = _threshold_slider('BBox Merge Threshold', 0.75)\n",
    "\n",
    "vbox = widgets.VBox([w_image, w_heatmap, w_bbox, w_bbox_thresh, w_merge, w_merge_thresh])\n",
    "\n",
    "display(vbox)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0830 13:57:23.088265 4690777536 deprecation.py:506] From /Users/carroll/anaconda3/envs/keras/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "W0830 13:57:27.219778 4690777536 deprecation_wrapper.py:119] From /Users/carroll/Git/keras-mobile-detectnet/model.py:21: The name tf.keras.backend.get_session is deprecated. Please use tf.compat.v1.keras.backend.get_session instead.\n",
      "\n",
      "W0830 13:57:27.220716 4690777536 deprecation_wrapper.py:119] From /Users/carroll/Git/keras-mobile-detectnet/model.py:22: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n",
      "W0830 13:57:27.527359 4690777536 deprecation.py:323] From /Users/carroll/Git/keras-mobile-detectnet/model.py:27: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.convert_variables_to_constants`\n",
      "W0830 13:57:27.528151 4690777536 deprecation.py:323] From /Users/carroll/anaconda3/envs/keras/lib/python3.7/site-packages/tensorflow/python/framework/graph_util_impl.py:270: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "W0830 13:57:27.949882 4690777536 deprecation.py:323] From /Users/carroll/Git/keras-mobile-detectnet/model.py:28: remove_training_nodes (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.remove_training_nodes`\n",
      "W0830 13:57:28.150229 4690777536 deprecation_wrapper.py:119] From /Users/carroll/Git/keras-mobile-detectnet/model.py:47: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n",
      "\n",
      "W0830 13:57:28.150995 4690777536 deprecation_wrapper.py:119] From /Users/carroll/Git/keras-mobile-detectnet/model.py:48: The name tf.GPUOptions is deprecated. Please use tf.compat.v1.GPUOptions instead.\n",
      "\n",
      "W0830 13:57:28.151535 4690777536 deprecation_wrapper.py:119] From /Users/carroll/Git/keras-mobile-detectnet/model.py:51: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Load the model\n",
    "# NOTE(review): `model` looks like the project-local model.py (it is not a\n",
    "# PyPI package name used elsewhere here) - confirm it is on the notebook's path.\n",
    "from model import MobileDetectNetModel\n",
    "import numpy as np\n",
    "import cv2  \n",
    "\n",
    "# Build the network, then read the trained weights from the WEIGHTS path that is\n",
    "# set in the first (configuration) cell.\n",
    "# NOTE(review): by_name=True presumably matches weights to layers by layer name,\n",
    "# as in the Keras load_weights API - confirm complete_model() returns a Keras model.\n",
    "keras_model = MobileDetectNetModel.complete_model()\n",
    "keras_model.load_weights(WEIGHTS, by_name=True)\n",
    "\n",
    "# Engine object whose infer() method the webcam loop calls on every frame\n",
    "tf_engine = keras_model.tf_engine()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Live webcam demo: grab a frame, run the detector on it, draw the optional\n",
    "# heatmap / bounding boxes, and push the result into the image widget.\n",
    "# Runs until the kernel is interrupted.\n",
    "\n",
    "# Side length (pixels) of the square image fed to the network\n",
    "INPUT_SIZE = 224\n",
    "# The network output is read as a GRID_SIZE x GRID_SIZE grid of cells\n",
    "GRID_SIZE = 7\n",
    "\n",
    "# Open the camera selected by DEVICE in the configuration cell\n",
    "cap = cv2.VideoCapture(DEVICE)\n",
    "\n",
    "try:\n",
    "    # Fail early with a clear message instead of crashing on the first frame\n",
    "    if not cap.isOpened():\n",
    "        raise RuntimeError('Could not open webcam at index {}'.format(DEVICE))\n",
    "\n",
    "    while True:\n",
    "        # Have to call this to get updated values from the sliders / checkboxes\n",
    "        ipython.kernel.do_one_iteration()\n",
    "\n",
    "        # Read the frame from the camera; skip frames the camera failed to deliver\n",
    "        ret, frame = cap.read()\n",
    "        if not ret or frame is None:\n",
    "            continue\n",
    "\n",
    "        img_original = frame\n",
    "        img_draw = img_original.copy()\n",
    "        img_resize = cv2.resize(img_original, (INPUT_SIZE, INPUT_SIZE))\n",
    "\n",
    "        # Factors that map network-input coordinates back onto the full-size frame\n",
    "        scale_width = img_original.shape[1] / INPUT_SIZE\n",
    "        scale_height = img_original.shape[0] / INPUT_SIZE\n",
    "\n",
    "        # The network expects the image to be scaled between -1 and 1,\n",
    "        # but most images are scaled between 0 and 255 normally.\n",
    "        # We divide by 127.5 to scale between 0 and 2, and subtract one to\n",
    "        # be between -1 and 1.\n",
    "        img_input = (img_resize / 127.5) - 1\n",
    "\n",
    "        # The neural network expects a batch of images as an input.\n",
    "        # This converts our single image with a shape of (224, 224, 3) to\n",
    "        # (1, 224, 224, 3); the 1 at the beginning is called the batch dimension.\n",
    "        batch = np.expand_dims(img_input, axis=0)\n",
    "\n",
    "        # Do the actual inference\n",
    "        bboxes, classes = tf_engine.infer(batch)\n",
    "\n",
    "        # Keep the box of every grid cell whose confidence reaches the slider value.\n",
    "        # Box values are multiplied by INPUT_SIZE to get network-input pixels.\n",
    "        rectangles = []\n",
    "        for y in range(GRID_SIZE):\n",
    "            for x in range(GRID_SIZE):\n",
    "                if classes[0, y, x, 0] >= w_bbox_thresh.value:\n",
    "                    rectangles.append(\n",
    "                        [int(coord * INPUT_SIZE) for coord in bboxes[0, y, x, :4]])\n",
    "\n",
    "        if w_heatmap.value:\n",
    "            cls_img = cv2.resize((classes[0]*255).astype(np.uint8), (img_draw.shape[1], img_draw.shape[0]), interpolation=cv2.INTER_AREA)\n",
    "            cls_cmap = cv2.applyColorMap(cls_img, cv2.COLORMAP_JET)\n",
    "            # Weight the colour map by the confidence in float arithmetic:\n",
    "            # multiplying two uint8 arrays wraps around modulo 256.\n",
    "            heat_weight = np.expand_dims(cls_img.astype(np.float32) / 255.0, axis=-1)\n",
    "            cls_add = img_draw.astype(np.float32) + heat_weight * cls_cmap.astype(np.float32)\n",
    "            # Rescale to the full 0-255 range; an all-black result has nothing to scale\n",
    "            peak = np.max(cls_add)\n",
    "            if peak > 0:\n",
    "                img_draw = (255*(cls_add / peak)).astype(np.uint8)\n",
    "\n",
    "        if w_merge.value and len(rectangles) > 0:\n",
    "            # eps comes from the merge slider (not the confidence slider).\n",
    "            # groupThreshold=1 keeps only clusters of at least two boxes.\n",
    "            # NOTE(review): groupRectangles documents its input as (x, y, w, h),\n",
    "            # while these boxes are drawn below as corner pairs - confirm the\n",
    "            # format the network emits.\n",
    "            rectangles, merge_counts = cv2.groupRectangles(rectangles, 1, eps=w_merge_thresh.value)\n",
    "\n",
    "        if w_bbox.value:\n",
    "            for rect in rectangles:\n",
    "                cv2.rectangle(img_draw,\n",
    "                              (int(rect[0]*scale_width), int(rect[1]*scale_height)),\n",
    "                              (int(rect[2]*scale_width), int(rect[3]*scale_height)),\n",
    "                              (0, 255, 0), 5)\n",
    "\n",
    "        # Visualization Code: only update the widget when PNG encoding succeeded\n",
    "        encoded_ok, img_png = cv2.imencode('.png', img_draw)\n",
    "        if encoded_ok:\n",
    "            w_image.value = img_png.tobytes()\n",
    "\n",
    "except KeyboardInterrupt:\n",
    "    # Interrupting the kernel is the intended way to stop the demo\n",
    "    pass\n",
    "finally:\n",
    "    # Always give the camera back, even when the loop ends with an error\n",
    "    cap.release()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
